local super = require "Object"

-- Sequence is the prototype for every sequence in this module: values are
-- produced by a getter(index) closure and the size by a length() closure
-- (see Sequence:new). It is assigned as a global (no `local`) and is also
-- the value this module returns.
Sequence = super:new()

-- Local alias of the global tonumber, so lookups hit a local instead of _G.
local tonumber = tonumber

--- Construct a Sequence from a parameter table.
-- @param params table with optional fields:
--   length  function() returning the element count (defaults to 0);
--           math.huge marks a scalar sequence
--   getter  function(index) returning the value(s) at index (defaults to nil)
--   each    optional function(self, func) used by Sequence:each instead of
--           the generic getter loop
--   cached  truthy when values are already materialized (see cachedCopy)
-- @return the new instance
function Sequence:new(params)
    local instance = super.new(self)

    local lengthFn = params.length
    if not lengthFn then
        lengthFn = function() return 0 end
    end
    local getterFn = params.getter
    if not getterFn then
        getterFn = function(index) end
    end

    instance._length = lengthFn
    instance._getter = getterFn
    instance._each = params.each
    instance._cached = params.cached

    return instance
end

--- Wrap a Lua array in a cached Sequence.
-- The table is captured by reference, not copied.
-- @param array  table of values (defaults to a new empty table)
-- @param length explicit element count (defaults to #array); pass it when
--               the array may contain nil holes, where # is unreliable
-- @return cached Sequence that indexes the array directly
function Sequence:newWithArray(array, length)
    local values = array or {}
    local count = length or #values

    local function valueAt(index)
        return values[index]
    end
    local function size()
        return count
    end

    return Sequence:new{ getter = valueAt, length = size, cached = true }
end

--- Create a cached Sequence that yields `value` at every index.
-- Scalars are encoded by an infinite length (math.huge); see isScalar.
-- @param value the constant value
-- @return cached scalar Sequence
function Sequence:newWithScalar(value)
    local function constant()
        return value
    end
    local function infinite()
        return math.huge
    end
    return Sequence:new{ getter = constant, length = infinite, cached = true }
end

-- Render a number as Lua source text that evaluates to exactly the same value.
-- tostring() is unsuitable for generated code for two reasons. Lua 5.1/LuaJIT
-- format numbers with "%.14g", which drops precision. They also render
-- infinities and NaN as "inf"/"-inf"/"nan", which the compiler reads as (nil)
-- global variable names.
local function numberLiteral(value)
    if value ~= value then
        return '(0/0)'
    elseif value == math.huge then
        return '(1/0)'
    elseif value == -math.huge then
        return '(-1/0)'
    end
    -- 17 significant digits always round-trip an IEEE-754 double.
    return string.format('%.17g', value)
end

--- Zip several sequences into one whose getter returns one value per input.
-- A specialized getter and `each` are compiled with loadstring. Each call
-- returns the input values as multiple results,
-- e.g. `return value1(index), 42, value3(index)`.
-- Scalar inputs holding nil, booleans or numbers are inlined as literals.
-- Other scalar values, and the getters of non-scalar inputs, are passed into
-- the generated code as upvalues.
-- NOTE: the entries of `sequences` are replaced in place by their cached copies.
-- @param sequences array of Sequence objects
-- @return cached Sequence whose length is the shortest input length
function Sequence:newWithSequenceList(sequences)
    local valueArgumentNames = {}
    local valueNames = {}
    local valueArguments = {}
    for index = 1, #sequences do
        local sequence = sequences[index]:cachedCopy()
        sequences[index] = sequence
        if sequence:isScalar() then
            local value = sequence:getValue(1)
            local valueType = type(value)
            if valueType == 'number' then
                -- Exact literal: tostring() would lose precision or emit inf/nan.
                valueNames[#valueNames + 1] = numberLiteral(value)
            elseif valueType == 'boolean' or valueType == 'nil' then
                valueNames[#valueNames + 1] = tostring(value)
            else
                local valueName = 'value' .. index
                valueArgumentNames[#valueArgumentNames + 1] = valueName
                valueNames[#valueNames + 1] = valueName
                valueArguments[#valueArguments + 1] = value
            end
        else
            local getterName = 'value' .. index
            valueArgumentNames[#valueArgumentNames + 1] = getterName
            valueNames[#valueNames + 1] = getterName .. '(index)'
            valueArguments[#valueArguments + 1] = sequence._getter
        end
    end
    valueArgumentNames = table.concat(valueArgumentNames, ', ')
    valueNames = table.concat(valueNames, ', ')

    -- Factory taking the upvalues as parameters and returning getter(index).
    local listGetter = loadstring([[
        return function(]] .. valueArgumentNames .. [[)
            return function(index)
                return ]] .. valueNames .. [[
            end
        end
    ]])()(unpack(valueArguments))

    -- Same factory shape, producing each(self, func) that calls func with
    -- all values of every index in 1..maxIndex().
    local listEach = loadstring([[
        return function(]] .. valueArgumentNames .. [[)
            return function(self, func)
                for index = 1, self:maxIndex() do
                    func(]] .. valueNames .. [[)
                end
            end
        end
    ]])()(unpack(valueArguments))

    return Sequence:new{
        getter = listGetter,
        each = listEach,
        length = function()
            local length = math.huge
            for index = 1, #sequences do
                length = math.min(length, sequences[index]:length())
            end
            return length
        end,
        cached = true
    }
end

--- Return the value(s) at `index` by delegating to the getter closure.
function Sequence:getValue(index)
    local getter = self._getter
    return getter(index)
end

--- Return the element count; math.huge denotes a scalar (see isScalar).
function Sequence:length()
    local lengthOf = self._length
    return lengthOf()
end

--- Highest index worth visiting: the length, or 1 for a scalar.
-- A scalar has infinite length but a single distinct value.
function Sequence:maxIndex()
    local length = self:length()
    return (length == math.huge) and 1 or length
end

--- True when the sequence is a single value repeated at every index.
function Sequence:isScalar()
    return math.huge == self:length()
end

--- True when the sequence has no elements.
function Sequence:isEmpty()
    return 0 == self:length()
end

--- Materialize the sequence into a plain table indexed 1..maxIndex().
-- A scalar yields a one-element table; nil values leave holes.
function Sequence:toArray()
    local fetch = self._getter
    local out = {}
    local last = self:maxIndex()
    for position = 1, last do
        out[position] = fetch(position)
    end
    return out
end

--- Stateful iterator over (index, value) for indexes 1..maxIndex().
-- The getter and bound are captured when iter() is called, and only the
-- first value returned by the getter is yielded.
-- Usage: for i, v in seq:iter() do ... end
function Sequence:iter()
    local fetch = self._getter
    local stop = self:maxIndex()
    local position = 0
    return function()
        local nextIndex = position + 1
        if nextIndex <= stop then
            local value = fetch(nextIndex)
            position = nextIndex
            return nextIndex, value
        end
    end
end

--- Call func with the value(s) at every index 1..maxIndex(), in order.
-- If a specialized `each` was supplied at construction, it is used instead.
function Sequence:each(func)
    local custom = self._each
    if custom then
        custom(self, func)
        return
    end
    local fetch = self._getter
    for position = 1, self:maxIndex() do
        func(fetch(position))
    end
end

--- Iterate two sequences in lock-step, yielding (index, value1, value2).
-- A scalar side is broadcast against the other side. Two scalars yield a
-- single step; otherwise iteration stops at the shorter length.
function Sequence:pairIter(other)
    local count = math.min(self:length(), other:length())
    if count == math.huge then
        count = 1
    end
    local left, right = self._getter, other._getter
    local position = 0
    return function()
        local nextIndex = position + 1
        if nextIndex <= count then
            local leftValue = left(nextIndex)
            local rightValue = right(nextIndex)
            position = nextIndex
            return nextIndex, leftValue, rightValue
        end
    end
end

--- Return a sequence with materialized values (self when already cached).
-- Scalars are re-wrapped with newWithScalar; other sequences are copied
-- element by element into an array-backed sequence.
function Sequence:cachedCopy()
    if self._cached then
        return self
    end
    if self:isScalar() then
        return Sequence:newWithScalar(self:getValue(1))
    end
    local fetch, snapshot = self._getter, {}
    local count = self:length()
    for position = 1, count do
        snapshot[position] = fetch(position)
    end
    return Sequence:newWithArray(snapshot, count)
end

--- Concatenate `other` after self; a scalar counts as one element.
-- Both lengths and getters are captured when join is called.
function Sequence:join(other)
    local firstCount, secondCount = self:maxIndex(), other:maxIndex()
    local firstGet, secondGet = self._getter, other._getter

    local function joined(index)
        if index <= firstCount then
            return firstGet(index)
        end
        local offset = index - firstCount
        if offset <= secondCount then
            return secondGet(offset)
        end
    end

    return Sequence:new{
        getter = joined,
        length = function()
            return firstCount + secondCount
        end,
    }
end

--- Lazily apply func to every element; the length follows self's length.
function Sequence:map(func)
    local source = self._getter
    local function mapped(index)
        return func(source(index))
    end
    local function size()
        return self:length()
    end
    return Sequence:new{ getter = mapped, length = size }
end

--- Lazily combine self and other element-wise with func(value1, value2).
-- The length is the shorter of the two lengths.
function Sequence:map2(func, other)
    local leftGet, rightGet = self._getter, other._getter
    local function combined(index)
        return func(leftGet(index), rightGet(index))
    end
    local function size()
        return math.min(self:length(), other:length())
    end
    return Sequence:new{ getter = combined, length = size }
end

--- Lazily combine self and any number of other sequences element-wise.
-- func receives one value per sequence, in argument order.
-- The length is the shortest input length.
function Sequence:multiMap(func, ...)
    local inputs = {self, ...}

    -- Return the values at `index` of inputs[position], inputs[position + 1],
    -- ... as multiple results (one per remaining input).
    local function valuesAt(index, position)
        if position <= #inputs then
            return inputs[position]:getValue(index), valuesAt(index, position + 1)
        end
    end

    local function shortestLength()
        local shortest = math.huge
        for position = 1, #inputs do
            shortest = math.min(shortest, inputs[position]:length())
        end
        return shortest
    end

    return Sequence:new{
        getter = function(index)
            return func(valuesAt(index, 1))
        end,
        length = shortestLength,
    }
end

--- Element-wise conditional: where self is truthy take ifTrue's value,
-- where it is false take ifFalse's value, and where it is nil yield nil.
-- The length is the shortest of the three lengths.
function Sequence:ifMap(ifTrue, ifFalse)
    local conditionAt = self._getter
    local whenTrue, whenFalse = ifTrue._getter, ifFalse._getter

    local function chosen(index)
        local condition = conditionAt(index)
        if condition == nil then
            return nil
        end
        if condition then
            return whenTrue(index)
        end
        return whenFalse(index)
    end

    local function size()
        local branchLength = math.min(ifTrue:length(), ifFalse:length())
        return math.min(self:length(), branchLength)
    end

    return Sequence:new{ getter = chosen, length = size }
end

--- Keep self's value where `other` is truthy at the same index.
-- Filtered-out positions yield no value. The length is unchanged (the
-- shorter of the two); use compact() to drop the gaps.
function Sequence:filter(other)
    local valueAt, keepAt = self._getter, other._getter
    local function kept(index)
        if keepAt(index) then
            return valueAt(index)
        end
    end
    local function size()
        return math.min(self:length(), other:length())
    end
    return Sequence:new{ getter = kept, length = size }
end

--- Drop nil values, returning a cached array-backed sequence of the rest.
-- Scalars are returned unchanged.
function Sequence:compact()
    if self:isScalar() then
        return self
    end
    local fetch = self._getter
    local kept, count = {}, 0
    for position = 1, self:length() do
        local value = fetch(position)
        if value ~= nil then
            count = count + 1
            kept[count] = value
        end
    end
    return Sequence:newWithArray(kept)
end

--- Blank out repeated values, keeping only the first occurrence of each.
-- The result has self's length; duplicates (and nils) become nil holes, so
-- combine with compact() to remove them.
-- Userdata exposing a `hash` method are deduplicated by that hash; every
-- other value is deduplicated by raw equality as a table key.
-- NaN cannot be a table key (Lua raises "table index is NaN"), so NaNs are
-- tracked separately and only the first NaN is kept.
-- Scalars are returned unchanged.
function Sequence:unique()
    if self:isScalar() then
        return self
    end
    local getter = self._getter
    local length = self:length()
    local seenValues = {}
    local seenHashes = {}
    local seenNaN = false
    local array = {}
    for index = 1, length do
        local value = getter(index)
        if value ~= nil then
            if type(value) == 'userdata' and type(value.hash) == 'function' then
                local hash = value:hash()
                if not seenHashes[hash] then
                    array[index] = value
                    seenHashes[hash] = true
                end
            elseif value ~= value then
                -- NaN: the only value not equal to itself.
                if not seenNaN then
                    array[index] = value
                    seenNaN = true
                end
            elseif not seenValues[value] then
                array[index] = value
                seenValues[value] = true
            end
        end
    end
    return Sequence:newWithArray(array, length)
end

-- Stable top-down merge sort over the index range [min, max] of `sequence`.
-- Values are never moved. Instead, destArray[min..max] is filled with the
-- original indexes, ordered by comparator applied to the values they point to.
-- On ties (comparator(...) <= 0) the element from the left half is taken
-- first, which keeps equal elements in their original order.
-- Where nil values end up is decided entirely by comparator.
-- @param sequence   Sequence whose values are compared (read via getValue)
-- @param min, max   inclusive index range; requires min <= max (an empty
--                   range never reaches the base case). sort() and rank()
--                   return early for length 0 before calling this.
-- @param destArray  table that receives the sorted indexes (written in place)
-- @param comparator function(a, b) returning a number; <= 0 keeps a before b
-- @return the number of nil values found in [min, max]
local function sortIndexes(sequence, min, max, destArray, comparator)
    local nils = 0
    -- Base case: a single index is already sorted; count it if its value is nil.
    if min == max then
        destArray[min] = min
        if sequence:getValue(min) == nil then
            nils = 1
        end
        return nils
    end
    -- Sort both halves recursively, accumulating their nil counts.
    local mid = math.floor((min + max) / 2)
    nils = nils + sortIndexes(sequence, min, mid, destArray, comparator)
    nils = nils + sortIndexes(sequence, mid + 1, max, destArray, comparator)
    -- Merge the sorted runs [min, mid] and [mid + 1, max] into tempArray.
    local tempArray = {}
    local min1, min2 = min, mid + 1
    for index = 1, max - min + 1 do
        if min1 > mid then
            -- Left run exhausted: take from the right run.
            tempArray[index] = destArray[min2]
            min2 = min2 + 1
        elseif min2 > max then
            -- Right run exhausted: take from the left run.
            tempArray[index] = destArray[min1]
            min1 = min1 + 1
        elseif comparator(sequence:getValue(destArray[min1]), sequence:getValue(destArray[min2])) <= 0 then
            -- Left value sorts first or ties; preferring left on ties keeps the sort stable.
            tempArray[index] = destArray[min1]
            min1 = min1 + 1
        else
            tempArray[index] = destArray[min2]
            min2 = min2 + 1
        end
    end
    -- Copy the merged run back into destArray[min..max].
    for index = 1, max - min + 1 do
        destArray[min + index - 1] = tempArray[index]
    end
    return nils
end

--- Reorder self's elements by the values of `other` (the sort keys).
-- Ascending order uses the global `compare`. The descending order negates
-- it and places nil keys after all non-nil keys.
-- The sorted index list is truncated by the number of nil keys, so elements
-- with nil keys are dropped. This assumes the comparator places nil keys
-- last (true for the descending comparator; TODO confirm for `compare`).
-- @param other      Sequence of sort keys
-- @param descending truthy for descending order
-- @return Sequence of self's values in sorted order
function Sequence:sort(other, descending)
    local count = other:maxIndex()
    if count == 0 then
        return Sequence:newWithArray({})
    end

    local ascending = compare
    local comparator = ascending
    if descending then
        comparator = function(a, b)
            if a == nil and b == nil then
                return 0
            end
            if a == nil then
                return 1
            end
            if b == nil then
                return -1
            end
            return -ascending(a, b)
        end
    end

    local order = {}
    local nilCount = sortIndexes(other, 1, count, order, comparator)
    return self:nth(Sequence:newWithArray(order, count - nilCount))
end

--- Rank every element by value, in the order given by the global `compare`.
-- Equal values (by ==) share the rank of their first position in sorted
-- order, as in competition ranking ("1224"). nil elements get no rank.
-- @return cached Sequence of ranks with self's maxIndex() length
function Sequence:rank()
    local count = self:maxIndex()
    if count == 0 then
        return Sequence:newWithArray({})
    end
    local fetch = self._getter
    local order = {}
    -- The nil count returned by sortIndexes is not needed: nil values simply
    -- receive no rank below.
    sortIndexes(self, 1, count, order, compare)

    local ranks = {}
    local runValue, runRank = nil, nil
    for position = 1, count do
        local originalIndex = order[position]
        local value = fetch(originalIndex)
        if value ~= runValue then
            runValue, runRank = value, position
        end
        if value ~= nil then
            ranks[originalIndex] = runRank
        end
    end
    return Sequence:newWithArray(ranks, count)
end

--- Lazily view the sequence back to front.
-- The `sequence` parameter is unused; it is kept for interface compatibility.
function Sequence:reverse(sequence)
    local source = self
    local function mirrored(index)
        local mirroredIndex = 1 + source:length() - index
        return source:getValue(mirroredIndex)
    end
    local function size()
        return source:length()
    end
    return Sequence:new{ getter = mirrored, length = size }
end

--- Left fold over the values at indexes 1..maxIndex():
-- func(...func(func(identity, v1), v2)..., vN).
function Sequence:reduce(identity, func)
    local iterate = self:iter()
    local accumulator = identity
    local position, value = iterate()
    while position ~= nil do
        accumulator = func(accumulator, value)
        position, value = iterate()
    end
    return accumulator
end

--- Call func on each value whose index lies in [rangeMin, rangeMax].
-- The range is clamped to 1..maxIndex(). Bounds that don't convert with
-- tonumber() fall back to the full range. Fractional bounds are rounded
-- inward (ceil for the minimum, floor for the maximum).
-- func's return values are discarded.
function Sequence:reduceRange(func, rangeMin, rangeMax)
    if not self then
        return
    end
    local fetch = self._getter
    local upper = self:maxIndex()
    local first, last = tonumber(rangeMin), tonumber(rangeMax)

    if first and first >= 1 then
        first = math.ceil(first)
    else
        first = 1
    end
    if last and last <= upper then
        last = math.floor(last)
    else
        last = upper
    end

    for position = first, last do
        func(fetch(position))
    end
end

--- Look up self's values at the positions listed in `other`.
-- A position is valid when it converts with tonumber(), is a positive
-- integer and is <= self:length(). Invalid positions yield nil.
-- A scalar `other` produces a scalar result.
function Sequence:nth(other)
    local function isValidIndex(position)
        return position and position > 0 and position <= self:length() and position == math.floor(position)
    end

    if other:isScalar() then
        local position = tonumber(other:getValue(1))
        local value
        if isValidIndex(position) then
            value = self:getValue(position)
        end
        return Sequence:newWithScalar(value)
    end

    local picked = {}
    local count = other:length()
    for slot = 1, count do
        local position = tonumber(other:getValue(slot))
        if isValidIndex(position) then
            picked[slot] = self:getValue(position)
        end
    end
    return Sequence:newWithArray(picked, count)
end

--- Return the first non-nil value; for a scalar, its value. Returns nothing
-- when every element is nil.
function Sequence:first()
    if not self:isScalar() then
        local iterate = self:iter()
        local position, candidate = iterate()
        while position ~= nil do
            if candidate ~= nil then
                return candidate
            end
            position, candidate = iterate()
        end
        return
    end
    return self:getValue(1)
end

--- Return the last non-nil value (via first() on the reversed view).
function Sequence:last()
    local backwards = self:reverse()
    return backwards:first()
end

--- Lazily offset the sequence: result[i] = self[i + indexDelta].
-- The result keeps self's length. A position whose source index
-- i + indexDelta falls outside 1..length() yields nothing. Such reads are
-- not forwarded, because derived getters (e.g. map) can produce non-nil
-- values for out-of-range indexes.
-- A scalar holds the same value everywhere, so shifting leaves it unchanged.
-- @param indexDelta offset added to each index (negative looks back)
function Sequence:shift(indexDelta)
    return Sequence:new{
        getter = function(index)
            local length = self:length()
            if index <= length then
                local sourceIndex = index + indexDelta
                if length == math.huge or (sourceIndex >= 1 and sourceIndex <= length) then
                    return self:getValue(sourceIndex)
                end
            end
        end,
        length = function()
            return self:length()
        end,
    }
end

-- Map func over self's values that convert with tonumber(). Values that
-- don't convert yield nothing. Returns nothing when self is nil/false.
local numberMap1 = function(self, func)
    if self then
        local function numericOnly(value)
            local number = tonumber(value)
            if number ~= nil then
                return func(number)
            end
        end
        return self:map(numericOnly)
    end
end

-- Combine self and other element-wise with func, yielding nothing where
-- either value is nil. Returns nothing when either operand is nil/false.
local map2 = function(self, func, other)
    if self and other then
        return self:map2(function(left, right)
            if left == nil or right == nil then
                return
            end
            return func(left, right)
        end, other)
    end
end

-- Combine self and other element-wise with func after converting both values
-- with tonumber(). Positions where either doesn't convert yield nothing.
-- Returns nothing when either operand is nil/false.
local numberMap2 = function(self, func, other)
    if not (self and other) then
        return
    end
    local function numericOnly(left, right)
        local a, b = tonumber(left), tonumber(right)
        if a == nil or b == nil then
            return
        end
        return func(a, b)
    end
    return self:map2(numericOnly, other)
end

-- Apply an ordering comparison func element-wise.
-- Two userdata values are compared directly only when they carry the same
-- non-nil `__eq` field (presumably meaning the same type; TODO confirm);
-- otherwise they yield nothing. Any other pair is compared after tonumber(),
-- and yields nothing when either side doesn't convert.
-- Returns nothing when either operand is nil/false.
local comparableMap2 = function(self, func, other)
    if not (self and other) then
        return
    end
    local function compareValues(left, right)
        if type(left) == 'userdata' and type(right) == 'userdata' then
            local eq = left.__eq
            if eq and eq == right.__eq then
                return func(left, right)
            end
            return
        end
        local a, b = tonumber(left), tonumber(right)
        if a ~= nil and b ~= nil then
            return func(a, b)
        end
    end
    return self:multiMap(compareValues, other)
end

-- Combine self and other element-wise with func after converting both values
-- with todisplaystring(). Positions where either conversion yields nil
-- produce nothing. Returns nothing when either operand is nil/false.
local stringMap2 = function(self, func, other)
    if not (self and other) then
        return
    end
    return self:map2(function(left, right)
        local a = todisplaystring(left)
        local b = todisplaystring(right)
        if a == nil or b == nil then
            return
        end
        return func(a, b)
    end, other)
end

-- Element-wise arithmetic metamethods. Operands go through tonumber()
-- (see numberMap2/numberMap1), and positions where a value is not numeric
-- yield nothing. Concatenation works on todisplaystring() values instead
-- (see stringMap2).
local function addValues(a, b) return a + b end
local function subtractValues(a, b) return a - b end
local function negateValue(a) return -a end
local function multiplyValues(a, b) return a * b end
local function divideValues(a, b) return a / b end
local function moduloValues(a, b) return a % b end
local function powerValues(a, b) return a ^ b end
local function concatValues(a, b) return a .. b end

function Sequence:__add(other)
    return numberMap2(self, addValues, other)
end

function Sequence:__sub(other)
    return numberMap2(self, subtractValues, other)
end

function Sequence:__unm()
    return numberMap1(self, negateValue)
end

function Sequence:__mul(other)
    return numberMap2(self, multiplyValues, other)
end

function Sequence:__div(other)
    return numberMap2(self, divideValues, other)
end

function Sequence:__mod(other)
    return numberMap2(self, moduloValues, other)
end

function Sequence:__pow(other)
    return numberMap2(self, powerValues, other)
end

function Sequence:__concat(other)
    return stringMap2(self, concatValues, other)
end

-- NOTE: Lua's built-in comparison metamethods (__eq, __lt, __le) cannot be
-- used here, because Lua converts their results to booleans. That would
-- discard the element-wise Sequence result. These "____"-prefixed methods are
-- therefore named differently and must be invoked explicitly by callers.
-- ____eq/____ne compare raw values with ==/~= (see the local map2).
-- The ordering comparisons go through comparableMap2 (numeric, or userdata
-- sharing the same __eq).
local function equalValues(a, b) return a == b end
local function unequalValues(a, b) return a ~= b end
local function lessValues(a, b) return a < b end
local function greaterValues(a, b) return a > b end
local function lessEqualValues(a, b) return a <= b end
local function greaterEqualValues(a, b) return a >= b end

function Sequence:____eq(other)
    return map2(self, equalValues, other)
end

function Sequence:____ne(other)
    return map2(self, unequalValues, other)
end

function Sequence:____lt(other)
    return comparableMap2(self, lessValues, other)
end

function Sequence:____gt(other)
    return comparableMap2(self, greaterValues, other)
end

function Sequence:____le(other)
    return comparableMap2(self, lessEqualValues, other)
end

function Sequence:____ge(other)
    return comparableMap2(self, greaterEqualValues, other)
end

return Sequence
